Skip to content
Merged
25 changes: 25 additions & 0 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,17 @@ struct MultiBoxTransformLocAttrs : public tvm::AttrsNode<MultiBoxTransformLocAtt
bool clip;
double threshold;
Array<IndexExpr> variances;
bool keep_background;

TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, "relay.attrs.MultiBoxTransformLocAttrs") {
TVM_ATTR_FIELD(clip).set_default(true).describe("Clip out-of-boundary boxes.");
TVM_ATTR_FIELD(threshold).set_default(0.01).describe("Threshold to be a positive prediction.");
TVM_ATTR_FIELD(variances)
.set_default(Array<IndexExpr>({0.1f, 0.1f, 0.2f, 0.2f}))
.describe("Variances to be decoded from box regression output.");
TVM_ATTR_FIELD(keep_background)
.set_default(false)
.describe("Whether to keep boxes detected as background or not");
}
};

Expand Down Expand Up @@ -129,6 +133,27 @@ struct AllClassNonMaximumSuppressionAttrs
}
};

/*! \brief Attributes used in regular_non_max_suppression operator */
struct RegularNonMaximumSuppressionAttrs
    : public tvm::AttrsNode<RegularNonMaximumSuppressionAttrs> {
  /*! \brief Maximum number of boxes selected for each class. */
  int32_t max_detections_per_class;
  /*! \brief Maximum number of boxes selected overall. */
  int32_t max_detections;
  /*! \brief Number of classes, not counting the background class. */
  int32_t num_classes;
  /*! \brief IoU threshold used by the box overlap test. */
  double iou_threshold;
  /*! \brief Score threshold used to drop low-score boxes early. */
  double score_threshold;

  // None of the fields declares a default, matching the original declaration.
  TVM_DECLARE_ATTRS(RegularNonMaximumSuppressionAttrs,
                    "relay.attrs.RegularNonMaximumSuppressionAttrs") {
    TVM_ATTR_FIELD(max_detections_per_class)
        .describe("The maximum number of output selected boxes per class.");
    TVM_ATTR_FIELD(max_detections).describe("The maximum number of output selected boxes.");
    TVM_ATTR_FIELD(num_classes).describe("The number of classes without background.");
    TVM_ATTR_FIELD(iou_threshold).describe("The IoU threshold for the box overlap test.");
    TVM_ATTR_FIELD(score_threshold)
        .describe("Score threshold to filter out low score boxes early.");
  }
};

/*! \brief Attributes used in roi_align operators */
struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
Array<IndexExpr> pooled_size;
Expand Down
45 changes: 32 additions & 13 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3443,12 +3443,7 @@ def convert_detection_postprocess(self, op):
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
custom_options = FlexBufferDecoder(flexbuffer).decode()

if "use_regular_nms" in custom_options:
if custom_options["use_regular_nms"]:
raise tvm.error.OpAttributeUnImplemented(
"use_regular_nms=True is not yet supported for operator "
"TFLite_Detection_PostProcess."
)
use_regular_nms = "use_regular_nms" in custom_options and custom_options["use_regular_nms"]

inputs = self.get_input_tensors(op)
assert len(inputs) == 3, "inputs length should be 3"
Expand Down Expand Up @@ -3481,15 +3476,14 @@ def convert_detection_postprocess(self, op):
input_zero_point=inputs[2].qnn_params["zero_point"],
)

# reshape the cls_pred and loc_prob tensors so
# they can be consumed by multibox_transform_loc
cls_pred = _op.transpose(cls_pred, [0, 2, 1])
# loc_prob coords are in yxhw format
# need to convert to xywh
loc_coords = _op.split(loc_prob, 4, axis=2)
loc_prob = _op.concatenate(
[loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
)
# reshape loc_prob tensor so it can be consumed by
# multibox_transform_loc
loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes * 4])

# anchor coords are in yxhw format
Expand All @@ -3511,13 +3505,41 @@ def convert_detection_postprocess(self, op):
# attributes for multibox_transform_loc
multibox_transform_loc_attrs = {}
multibox_transform_loc_attrs["clip"] = False
multibox_transform_loc_attrs["threshold"] = custom_options["nms_score_threshold"]
multibox_transform_loc_attrs["threshold"] = (
0.0 if use_regular_nms else custom_options["nms_score_threshold"]
)
multibox_transform_loc_attrs["variances"] = (
1 / custom_options["x_scale"],
1 / custom_options["y_scale"],
1 / custom_options["w_scale"],
1 / custom_options["h_scale"],
)
multibox_transform_loc_attrs["keep_background"] = use_regular_nms

ret = _op.vision.multibox_transform_loc(
# reshape cls_pred so it can be consumed by
# multibox_transform_loc
_op.transpose(cls_pred, [0, 2, 1]),
loc_prob,
anchor_expr,
**multibox_transform_loc_attrs,
)

if use_regular_nms:
# box coordinates need to be converted from ltrb to (ymin, xmin, ymax, xmax)
_, transformed_boxes = _op.split(ret[0], (2,), axis=2)
box_l, box_t, box_r, box_b = _op.split(transformed_boxes, 4, axis=2)
transformed_boxes = _op.concatenate([box_t, box_l, box_b, box_r], axis=2)

return _op.vision.regular_non_max_suppression(
boxes=transformed_boxes,
scores=cls_pred,
max_detections_per_class=custom_options["detections_per_class"],
max_detections=custom_options["max_detections"],
num_classes=custom_options["num_classes"],
iou_threshold=custom_options["nms_iou_threshold"],
score_threshold=custom_options["nms_score_threshold"],
)

# attributes for non_max_suppression
non_max_suppression_attrs = {}
Expand All @@ -3528,9 +3550,6 @@ def convert_detection_postprocess(self, op):
non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"]
non_max_suppression_attrs["invalid_to_bottom"] = False

ret = _op.vision.multibox_transform_loc(
cls_pred, loc_prob, anchor_expr, **multibox_transform_loc_attrs
)
ret = _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **non_max_suppression_attrs)
ret = _op.vision.get_valid_counts(ret, 0)
valid_count = ret[0]
Expand Down
34 changes: 33 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,10 @@ def _compute_multibox_transform_loc(attrs, inputs, _):
clip = bool(get_const_int(attrs.clip))
threshold = get_const_float(attrs.threshold)
variances = get_float_tuple(attrs.variances)
return topi_compute(inputs[0], inputs[1], inputs[2], clip, threshold, variances)
keep_background = bool(get_const_int(attrs.keep_background))
return topi_compute(
inputs[0], inputs[1], inputs[2], clip, threshold, variances, keep_background
)

return _compute_multibox_transform_loc

Expand Down Expand Up @@ -1316,6 +1319,35 @@ def all_class_nms_strategy(attrs, inputs, out_type, target):
return strategy


def wrap_compute_regular_nms(topi_compute):
    """Wrap a regular nms topi compute into an ``(attrs, inputs, out_type)`` callable.

    Parameters
    ----------
    topi_compute : callable
        Called positionally with the first two inputs followed by
        max_detections_per_class, max_detections, num_classes,
        iou_threshold and score_threshold read from ``attrs``.

    Returns
    -------
    callable
        Function that forwards to ``topi_compute`` and returns its result.
    """

    def _compute_nms(attrs, inputs, out_type):
        # out_type is part of the expected signature but is not used.
        boxes, scores = inputs[0], inputs[1]
        nms_params = (
            attrs.max_detections_per_class,
            attrs.max_detections,
            attrs.num_classes,
            attrs.iou_threshold,
            attrs.score_threshold,
        )
        return topi_compute(boxes, scores, *nms_params)

    return _compute_nms


@override_native_generic_func("regular_non_max_suppression_strategy")
def regular_nms_strategy(attrs, inputs, out_type, target):
    """Generic strategy for regular nms.

    Builds an OpStrategy with a single implementation named
    "regular_nms.generic". The attrs, inputs, out_type and target
    arguments are accepted but not read here.
    """
    nms_strategy = _op.OpStrategy()
    compute = wrap_compute_regular_nms(topi.vision.regular_non_max_suppression)
    schedule = wrap_topi_schedule(topi.generic.schedule_nms)
    nms_strategy.add_implementation(compute, schedule, name="regular_nms.generic")
    return nms_strategy


# roi_align
def wrap_compute_roi_align(topi_compute):
"""wrap roi_align topi compute"""
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/relay/op/vision/_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
reg.register_strategy("vision.all_class_non_max_suppression", strategy.all_class_nms_strategy)
reg.register_pattern("vision.all_class_non_max_suppression", OpPattern.OPAQUE)

# Register the strategy for vision.regular_non_max_suppression and tag the op
# with the OPAQUE pattern, as is done for the all-class NMS op registered above.
reg.register_strategy("vision.regular_non_max_suppression", strategy.regular_nms_strategy)
reg.register_pattern("vision.regular_non_max_suppression", OpPattern.OPAQUE)


@script
def _get_valid_counts_shape_func(data_shape):
Expand Down Expand Up @@ -122,6 +125,33 @@ def all_class_nms_shape_func(attrs, inputs, _):
return _all_class_nms_shape_func_tf(inputs[0], inputs[1])


# Shape function body for regular nms (compiled via @script).
# Produces four output shapes from the leading dimension of boxes_shape and
# attrs.max_detections:
#   boxes          -> (boxes_shape[0], max_detections, 4)
#   classes        -> (boxes_shape[0], max_detections)
#   scores         -> (boxes_shape[0], max_detections)
#   num_detections -> (boxes_shape[0],)
# scores_shape is accepted but not read.
@script
def _regular_nms_shape_func(boxes_shape, scores_shape, attrs):
# Each tensor below stores a shape, so its length equals the rank of that output.
out_boxes_shape = output_tensor((3,), "int64")
out_classes_shape = output_tensor((2,), "int64")
out_scores_shape = output_tensor((2,), "int64")
out_num_detections_shape = output_tensor((1,), "int64")

# Leading dimension is copied from the input boxes; the last is the 4 box coordinates.
out_boxes_shape[0] = boxes_shape[0]
out_boxes_shape[1] = int64(attrs.max_detections)
out_boxes_shape[2] = int64(4)

out_classes_shape[0] = boxes_shape[0]
out_classes_shape[1] = int64(attrs.max_detections)

out_scores_shape[0] = boxes_shape[0]
out_scores_shape[1] = int64(attrs.max_detections)

# One detection count per entry along the leading dimension.
out_num_detections_shape[0] = boxes_shape[0]

return out_boxes_shape, out_classes_shape, out_scores_shape, out_num_detections_shape


@reg.register_shape_func("vision.regular_non_max_suppression", False)
def regular_nms_shape_func(attrs, inputs, _):
    """Shape function registered for vision.regular_non_max_suppression.

    Forwards the first two inputs and the op attrs to
    _regular_nms_shape_func and returns its result. The third argument
    is ignored.
    """
    boxes_shape, scores_shape = inputs[0], inputs[1]
    return _regular_nms_shape_func(boxes_shape, scores_shape, attrs)


@script
def _roi_align_shape_func_nchw(data_shape, rois_shape, pooled_size):
out = output_tensor((4,), "int64")
Expand Down
22 changes: 20 additions & 2 deletions python/tvm/relay/op/vision/multibox.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ def multibox_prior(


def multibox_transform_loc(
cls_prob, loc_pred, anchor, clip=True, threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)
cls_prob,
loc_pred,
anchor,
clip=True,
threshold=0.01,
variances=(0.1, 0.1, 0.2, 0.2),
keep_background=False,
):
"""Location transformation for multibox detection

Expand All @@ -77,10 +83,22 @@ def multibox_transform_loc(
variances : Tuple of float, optional
variances to be decoded from box regression output.

keep_background : boolean, optional
Whether to keep boxes detected as background or not.

Returns
-------
ret : tuple of tvm.relay.Expr
"""
return expr.TupleWrapper(
_make.multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances), 2
_make.multibox_transform_loc(
cls_prob,
loc_pred,
anchor,
clip,
threshold,
variances,
keep_background,
),
2,
)
59 changes: 59 additions & 0 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,62 @@ def all_class_non_max_suppression(
return expr.TupleWrapper(out, 2)

return expr.TupleWrapper(out, 3)


def regular_non_max_suppression(
    boxes,
    scores,
    max_detections_per_class,
    max_detections,
    num_classes,
    iou_threshold,
    score_threshold,
):
    """Regular non-maximum suppression for object detection, matching TFLite's
    regular NMS. Suppression is carried out separately for every class.

    Parameters
    ----------
    boxes : relay.Expr
        3-D tensor of shape (batch_size, num_boxes, 4). Each box is encoded
        as (ymin, xmin, ymax, xmax).

    scores : relay.Expr
        3-D tensor of shape (batch_size, num_boxes, num_classes_with_background).

    max_detections_per_class : int
        The maximum number of output selected boxes per class.

    max_detections : int
        The maximum number of output selected boxes.

    num_classes : int
        The number of classes without background.

    iou_threshold : float
        IoU test threshold.

    score_threshold : float
        Score threshold to filter out low score boxes early.

    Returns
    -------
    out : relay.Tuple
        A relay.Tuple of four tensors, in order:
        `detection_boxes` of size `(batch_size, max_detections, 4)`,
        `detection_classes` of size `(batch_size, max_detections)`,
        `detection_scores` of size `(batch_size, max_detections)`, and
        `num_detections` of size `(batch_size,)` giving the total number of
        selected boxes per batch.
    """
    nms_call = _make.regular_non_max_suppression(
        boxes,
        scores,
        max_detections_per_class,
        max_detections,
        num_classes,
        iou_threshold,
        score_threshold,
    )
    # The op has four outputs; wrap the call so each can be indexed.
    return expr.TupleWrapper(nms_call, 4)
Loading