Skip to content

Commit

Permalink
[RELAY][OP] Faster-RCNN Proposal OP (#2725)
Browse files Browse the repository at this point in the history
* [RELAY][OP] Proposal

* Fix

* Fix test
  • Loading branch information
vinx13 authored and Laurawly committed Mar 5, 2019
1 parent c8a3a59 commit fe06049
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 3 deletions.
38 changes: 38 additions & 0 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,44 @@ struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
}
};

/*! \brief Attributes used in proposal operators */
// Each field below is declared as a plain member and then registered, with its
// default value and a human-readable description, inside TVM_DECLARE_ATTRS.
struct ProposalAttrs : public tvm::AttrsNode<ProposalAttrs> {
// Anchor scales enumerated when generating anchor windows.
Array<IndexExpr> scales;
// Anchor aspect ratios enumerated when generating anchor windows.
Array<IndexExpr> ratios;
// Size of the receptive field of each unit in the RPN convolution layer.
int feature_stride;
// IoU threshold used by non-maximum suppression.
double threshold;
// Number of top scoring boxes NMS is applied to; -1 means all boxes.
int rpn_pre_nms_top_n;
// Number of top scoring boxes kept after NMS.
int rpn_post_nms_top_n;
// Minimum height or width of a proposal.
int rpn_min_size;
// Whether IoU loss is used.
bool iou_loss;

TVM_DECLARE_ATTRS(ProposalAttrs, "relay.attrs.ProposalAttrs") {
TVM_ATTR_FIELD(scales)
.set_default(Array<IndexExpr>({4.0f, 8.0f, 16.0f, 32.0f}))
.describe("Used to generate anchor windows by enumerating scales");
TVM_ATTR_FIELD(ratios)
.set_default(Array<IndexExpr>({0.5f, 1.0f, 2.0f}))
.describe("Used to generate anchor windows by enumerating ratios");
// NOTE(review): the two adjacent string literals below concatenate without a
// space after the comma ("...rpn,for example..."); left unchanged here.
TVM_ATTR_FIELD(feature_stride)
.set_default(16)
.describe(
"The size of the receptive field each unit in the convolution layer of the rpn,"
"for example the product of all stride's prior to this layer.");
// NOTE(review): "suppresion" in the description below is a typo for
// "suppression"; left unchanged here because it is a runtime string.
TVM_ATTR_FIELD(threshold)
.set_default(0.7)
.describe(
"IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)");
TVM_ATTR_FIELD(rpn_pre_nms_top_n)
.set_default(6000)
.describe("Number of top scoring boxes to apply NMS. -1 to use all boxes");
TVM_ATTR_FIELD(rpn_post_nms_top_n)
.set_default(300)
.describe("Number of top scoring boxes to keep after applying NMS to RPN proposals");
TVM_ATTR_FIELD(rpn_min_size).set_default(16).describe("Minimum height or width in proposal");
TVM_ATTR_FIELD(iou_loss).set_default(false).describe("Usage of IoU Loss");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_VISION_H_
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,20 @@ def _mx_roi_align(inputs, attrs):
return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs)


def _mx_proposal(inputs, attrs):
new_attrs = {}
new_attrs["scales"] = attrs.get_float_tuple("scales", (4.0, 8.0, 16.0, 32.0))
new_attrs["ratios"] = attrs.get_float_tuple("ratios", (0.5, 1.0, 2.0))
new_attrs["feature_stride"] = attrs.get_int("feature_stride", 16)
new_attrs["threshold"] = attrs.get_float("threshold", 0.7)
new_attrs["rpn_pre_nms_top_n"] = attrs.get_int("rpn_pre_nms_top_n", 6000)
new_attrs["rpn_post_nms_top_n"] = attrs.get_int("rpn_post_nms_top_n", 300)
new_attrs["rpn_min_size"] = attrs.get_int("rpn_min_size", 16)
new_attrs["iou_loss"] = attrs.get_bool("iou_loss", False)
assert not attrs.get_bool("output_score", False), "proposal doesn't support output score"
return _op.vision.proposal(inputs[0], inputs[1], inputs[2], **new_attrs)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -466,6 +480,8 @@ def _mx_roi_align(inputs, attrs):
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
"_contrib_MultiBoxDetection" : _mx_multibox_detection,
"_contrib_ROIAlign" : _mx_roi_align,
"_contrib_Proposal" : _mx_proposal,
"_contrib_MultiProposal" : _mx_proposal,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
Expand Down
28 changes: 27 additions & 1 deletion python/tvm/relay/op/vision/_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=invalid-name, unused-argument
"""Faster R-CNN and Mask R-CNN operations."""
import topi
from topi.util import get_const_tuple
from topi.util import get_const_tuple, get_float_tuple, get_const_int
from .. import op as reg
from ..op import OpPattern

Expand All @@ -21,3 +21,29 @@ def schedule_roi_align(_, outs, target):
return topi.generic.vision.schedule_roi_align(outs)

reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("vision.proposal")
def compute_proposal(attrs, inputs, _, target):
    """Build the TOPI compute for ``vision.proposal``.

    Parameters
    ----------
    attrs : attribute node
        Supplies ``scales``, ``ratios``, ``feature_stride``, ``threshold``,
        ``rpn_pre_nms_top_n``, ``rpn_post_nms_top_n``, ``rpn_min_size`` and
        ``iou_loss``.
    inputs : list
        The first three entries are forwarded, in order, to
        ``topi.vision.rcnn.proposal``.
    _ : unused
    target : context manager
        Entered around the TOPI call.

    Returns
    -------
    list
        Single-element list holding the TOPI proposal output.
    """
    # Anchor description is converted to plain Python float tuples.
    anchor_scales = get_float_tuple(attrs.scales)
    anchor_ratios = get_float_tuple(attrs.ratios)
    # The flag goes through get_const_int before being turned into a bool.
    use_iou_loss = bool(get_const_int(attrs.iou_loss))
    # Remaining attributes are forwarded unchanged.
    scalar_args = (attrs.feature_stride, attrs.threshold, attrs.rpn_pre_nms_top_n,
                   attrs.rpn_post_nms_top_n, attrs.rpn_min_size)
    with target:
        out = topi.vision.rcnn.proposal(inputs[0], inputs[1], inputs[2],
                                        anchor_scales, anchor_ratios,
                                        *scalar_args, use_iou_loss)
        return [out]

@reg.register_schedule("vision.proposal")
def schedule_proposal(_, outs, target):
    """Return the generic TOPI schedule for ``vision.proposal``.

    ``outs`` is handed to ``topi.generic.schedule_proposal`` while ``target``
    is entered as a context manager; the first argument is unused.
    """
    with target:
        proposal_schedule = topi.generic.schedule_proposal(outs)
        return proposal_schedule

# Register vision.proposal with the OPAQUE operator pattern.
# NOTE(review): presumably this keeps the op out of operator fusion -- confirm
# against the OpPattern definition.
reg.register_pattern("vision.proposal", OpPattern.OPAQUE)
60 changes: 60 additions & 0 deletions python/tvm/relay/op/vision/rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,63 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='N
4-D tensor with shape [num_roi, channel, pooled_size, pooled_size]
"""
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)


def proposal(cls_prob,
             bbox_pred,
             im_info,
             scales,
             ratios,
             feature_stride,
             threshold,
             rpn_pre_nms_top_n,
             rpn_post_nms_top_n,
             rpn_min_size,
             iou_loss):
    """Proposal operator.

    Parameters
    ----------
    cls_prob : relay.Expr
        4-D tensor with shape [batch, 2 * num_anchors, height, width].

    bbox_pred : relay.Expr
        4-D tensor with shape [batch, 4 * num_anchors, height, width].

    im_info : relay.Expr
        2-D tensor with shape [batch, 3]. The last dimension should be in format of
        [im_height, im_width, im_scale]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field each unit in the convolution layer of the rpn, for example
        the product of all stride's prior to this layer.

    threshold : float
        Non-maximum suppression threshold.

    rpn_pre_nms_top_n : int
        Number of top scoring boxes to apply NMS. -1 to use all boxes.

    rpn_post_nms_top_n : int
        Number of top scoring boxes to keep after applying NMS to RPN proposals.

    rpn_min_size : int
        Minimum height or width in proposal.

    iou_loss : bool
        Usage of IoU loss.

    Returns
    -------
    output : relay.Expr
        2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
        [batch_index, w_start, h_start, w_end, h_end].
    """
    # Tensor operands first, then the attribute values in the positional order
    # expected by the registered ``_make.proposal`` function.
    tensor_args = (cls_prob, bbox_pred, im_info)
    attr_args = (scales, ratios, feature_stride, threshold,
                 rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss)
    return _make.proposal(*tensor_args, *attr_args)
67 changes: 67 additions & 0 deletions src/relay/op/vision/rcnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,72 @@ RELAY_REGISTER_OP("vision.roi_align")
.set_support_level(5)
.add_type_rel("ROIAlign", ROIAlignRel);

// Register the ProposalAttrs node type with TVM.
// NOTE(review): presumably required so the attrs can be created and reflected
// by type key -- confirm against the TVM_REGISTER_NODE_TYPE definition.
TVM_REGISTER_NODE_TYPE(ProposalAttrs);

/*!
 * \brief Type relation of vision.proposal.
 *
 * \param types Three input types (cls_prob, bbox_pred, im_info) followed by
 *        the output type.
 * \param num_inputs Number of inputs (unused).
 * \param attrs Operator attributes; must hold a ProposalAttrs node.
 * \param reporter Used to assert shape constraints and assign the output type.
 * \return false while any input is not yet a tensor type, true once the
 *         output type [batch * rpn_post_nms_top_n, 5] has been assigned.
 */
bool ProposalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* cls_prob = types[0].as<TensorTypeNode>();
  const auto* bbox_pred = types[1].as<TensorTypeNode>();
  const auto* im_info = types[2].as<TensorTypeNode>();

  // Inputs whose types are not resolved to tensors yet: try again later.
  if (!cls_prob || !bbox_pred || !im_info) {
    return false;
  }

  // as<>() yields nullptr when the attrs node has a different type; fail with
  // a clear message instead of dereferencing a null pointer further down.
  const auto* proposal_attrs = attrs.as<ProposalAttrs>();
  CHECK(proposal_attrs != nullptr)
      << "vision.proposal expects attributes of type ProposalAttrs";

  CHECK_EQ(cls_prob->shape.size(), 4U)
      << "The dimension of class probability should be 4, but received " << cls_prob->shape.size();
  CHECK_EQ(bbox_pred->shape.size(), 4U)
      << "The dimension of box prediction should be 4, but received " << bbox_pred->shape.size();
  CHECK_EQ(im_info->shape.size(), 2U)
      << "The dimension of image info should be 2, but received " << im_info->shape.size();
  // The second dimension of im_info must be 3.
  CHECK(reporter->AssertEQ(im_info->shape[1], 3));

  auto batch = cls_prob->shape[0];

  // One row of 5 values per kept proposal; the dtype follows cls_prob.
  std::vector<IndexExpr> oshape(
      {batch * proposal_attrs->rpn_post_nms_top_n, 5});
  reporter->Assign(types[3], TensorTypeNode::make(oshape, cls_prob->dtype));
  return true;
}

/*!
 * \brief Build a call to vision.proposal.
 *
 * The three tensor expressions become the call arguments; every remaining
 * parameter is copied into a freshly allocated ProposalAttrs node.
 *
 * \return The call expression for the "vision.proposal" operator.
 */
Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array<IndexExpr> scales,
                  Array<IndexExpr> ratios, int feature_stride, double threshold,
                  int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size,
                  bool iou_loss) {
  auto node = make_node<ProposalAttrs>();

  // Anchor generation.
  node->scales = scales;
  node->ratios = ratios;
  node->feature_stride = feature_stride;

  // NMS and box filtering.
  node->threshold = threshold;
  node->rpn_pre_nms_top_n = rpn_pre_nms_top_n;
  node->rpn_post_nms_top_n = rpn_post_nms_top_n;
  node->rpn_min_size = rpn_min_size;

  node->iou_loss = iou_loss;

  // Function-local static: the operator lookup runs only on the first call.
  static const Op& proposal_op = Op::Get("vision.proposal");
  return CallNode::make(proposal_op, {cls_prob, bbox_pred, im_info}, Attrs(node), {});
}

// Register MakeProposal as the global function "relay.op.vision._make.proposal".
// The template argument 11 matches MakeProposal's parameter count
// (3 tensor expressions + 8 attribute values).
TVM_REGISTER_API("relay.op.vision._make.proposal")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 11>(MakeProposal, args, rv);
});

// Operator registration for vision.proposal: three tensor inputs
// (cls_prob, bbox_pred, im_info) whose output type is derived by ProposalRel.
RELAY_REGISTER_OP("vision.proposal")
.describe(R"code(Generate region proposals via RPN.
- **cls_prob**: 4-D with shape [batch, 2 * num_anchors, height, width].
- **bbox_pred**: 4-D with shape [batch, 4 * num_anchors, height, width].
- **im_info**: 2-D with shape [batch, 3].
- **out**: 2-D with shape [batch * rpn_post_nms_top_n, 5].
)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("cls_prob", "Tensor", "Score of how likely proposal is object")
.add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals")
.add_argument("im_info", "Tensor", "Image size and scale")
.set_support_level(5)
.add_type_rel("Proposal", ProposalRel);

} // namespace relay
} // namespace tvm
67 changes: 67 additions & 0 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,72 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_
verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2)


def test_proposal():
    """Check type inference and CUDA results of ``relay.vision.proposal``.

    The same inputs are run twice: once with ``iou_loss`` disabled and once
    with it enabled, each against its own expected output.
    """
    def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
        """Build the op, check its inferred type and compare against ``np_out``."""
        np_inputs = (np_cls_prob, np_bbox_pred, np_im_info)
        cls_prob = relay.var("cls_prob", relay.ty.TensorType(np_cls_prob.shape, "float32"))
        bbox_pred = relay.var("bbox_pred", relay.ty.TensorType(np_bbox_pred.shape, "float32"))
        im_info = relay.var("im_info", relay.ty.TensorType(np_im_info.shape, "float32"))
        z = relay.vision.proposal(cls_prob, bbox_pred, im_info, **attrs)

        # The inferred output type must match the expected array's shape.
        inferred = relay.ir_pass.infer_type(z)
        assert inferred.checked_type == relay.ty.TensorType(np_out.shape, "float32")

        func = relay.ir_pass.infer_type(relay.Function([cls_prob, bbox_pred, im_info], z))
        # Only the cuda target is exercised.
        for target in ['cuda']:
            if not tvm.module.enabled(target):
                print("Skip test because %s is not enabled." % target)
                continue
            ctx = tvm.context(target, 0)
            # Graph executor first, then the debug executor.
            for executor_kind in ("graph", "debug"):
                intrp = relay.create_executor(executor_kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(*np_inputs)
                tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=1e-4)

    attrs = dict(
        scales=(0.5,),
        ratios=(0.5,),
        feature_stride=16,
        iou_loss=False,
        rpn_min_size=16,
        threshold=0.7,
        rpn_pre_nms_top_n=200,
        rpn_post_nms_top_n=4,
    )

    np_cls_prob = np.array([[
        [[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]],
        [[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]]
    ]], dtype='float32')
    np_bbox_pred = np.array([[
        [[0.5, 1.0, 0.6], [0.8, 1.2, 2.0], [0.9, 1.0, 0.8]],
        [[0.5, 1.0, 0.7], [0.8, 1.2, 1.6], [2.1, 1.5, 0.7]],
        [[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
        [[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
    ]], dtype='float32')
    np_im_info = np.array([[48., 48., 1.]], dtype='float32')

    # Expected proposals with iou_loss disabled.
    expected_default = np.array([
        [0., 0., 2.8451548, 28.38012, 18.154846],
        [0., 0., 15.354933, 41.96971, 41.245064],
        [0., 18.019852, 1.0538368, 51.98015, 25.946163],
        [0., 27.320923, -1.266357, 55., 24.666357]
    ], dtype='float32')
    verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, expected_default, attrs)

    # Expected proposals with iou_loss enabled.
    expected_iou_loss = np.array([
        [0., -5.25, -2.5, 21.75, 19.],
        [0., 11.25, -2., 37.25, 18.5],
        [0., 26.849998, -2.3000002, 53.45, 18.6],
        [0., -4.95, 13.799999, 22.25, 35.5]
    ], dtype='float32')
    attrs['iou_loss'] = True
    verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, expected_iou_loss, attrs)


def test_yolo_reorg_infer_shape():
def verify_yolo_reorg(shape, stride, out_shape):
x = relay.var("x", relay.TensorType(shape, "float32"))
Expand Down Expand Up @@ -347,5 +413,6 @@ def verify_yolo_reorg(shape, stride):
test_multibox_transform_loc()
test_nms()
test_roi_align()
test_proposal()
test_yolo_reorg_infer_shape()
test_yolo_reorg()
4 changes: 2 additions & 2 deletions topi/tests/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_roi_align():
def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
cls_prob = tvm.placeholder(np_cls_prob.shape)
bbox_pred = tvm.placeholder(np_bbox_pred.shape)
im_info = tvm.placeholder(np_im_info.shape, dtype='int32')
im_info = tvm.placeholder(np_im_info.shape)

def check_device(device):
ctx = tvm.context(device, 0)
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_proposal():
[[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
[[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
]], dtype='float32')
np_im_info = np.array([[48, 48, 1]], dtype='int32')
np_im_info = np.array([[48., 48., 1.]], dtype='float32')
np_out = np.array([
[0., 0., 2.8451548,28.38012, 18.154846],
[0., 0., 15.354933, 41.96971, 41.245064],
Expand Down

0 comments on commit fe06049

Please sign in to comment.