Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Make mrcnn_mask_target arg mask_size a 2d tuple (#16567)
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L authored and apeforest committed Nov 14, 2019
1 parent c4580ae commit 1cac460
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 22 deletions.
8 changes: 5 additions & 3 deletions src/operator/contrib/mrcnn_target-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,17 @@ namespace mrcnn_index {
struct MRCNNTargetParam : public dmlc::Parameter<MRCNNTargetParam> {
int num_rois;
int num_classes;
int mask_size;
int sample_ratio;
mxnet::TShape mask_size;

DMLC_DECLARE_PARAMETER(MRCNNTargetParam) {
DMLC_DECLARE_FIELD(num_rois)
.describe("Number of sampled RoIs.");
DMLC_DECLARE_FIELD(num_classes)
.describe("Number of classes.");
DMLC_DECLARE_FIELD(mask_size)
.describe("Size of the pooled masks.");
.set_expect_ndim(2).enforce_nonzero()
.describe("Size of the pooled masks height and width: (h, w).");
DMLC_DECLARE_FIELD(sample_ratio).set_default(2)
.describe("Sampling ratio of ROI align. Set to -1 to use adaptative size.");
}
Expand Down Expand Up @@ -91,7 +92,8 @@ inline bool MRCNNTargetShape(const NodeAttrs& attrs,
CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs.";

// out: 2 * (B, N, C, MS_H, MS_W), with (MS_H, MS_W) = (mask_size[0], mask_size[1])
auto oshape = Shape5(batch_size, num_rois, param.num_classes, param.mask_size, param.mask_size);
auto oshape = Shape5(batch_size, num_rois, param.num_classes,
param.mask_size[0], param.mask_size[1]);
out_shape->clear();
out_shape->push_back(oshape);
out_shape->push_back(oshape);
Expand Down
40 changes: 21 additions & 19 deletions src/operator/contrib/mrcnn_target.cu
Original file line number Diff line number Diff line change
Expand Up @@ -183,28 +183,29 @@ __device__ void RoIAlignForward(


template<typename DType>
__global__ void MRCNNTargetKernel(const DType *rois,
const DType *gt_masks,
const DType *matches,
const DType *cls_targets,
DType* sampled_masks,
DType* mask_cls,
const int total_out_el,
int batch_size,
int num_classes,
int num_rois,
int num_gtmasks,
int gt_height,
int gt_width,
int mask_size,
int sample_ratio) {
__global__ void MRCNNMaskTargetKernel(const DType *rois,
const DType *gt_masks,
const DType *matches,
const DType *cls_targets,
DType* sampled_masks,
DType* mask_cls,
const int total_out_el,
int batch_size,
int num_classes,
int num_rois,
int num_gtmasks,
int gt_height,
int gt_width,
int mask_size_h,
int mask_size_w,
int sample_ratio) {
// computing sampled_masks
RoIAlignForward(gt_masks, rois, matches, total_out_el,
num_classes, gt_height, gt_width, mask_size, mask_size,
num_classes, gt_height, gt_width, mask_size_h, mask_size_w,
sample_ratio, num_rois, num_gtmasks, sampled_masks);
// computing mask_cls
int num_masks = batch_size * num_rois * num_classes;
int mask_vol = mask_size * mask_size;
int mask_vol = mask_size_h * mask_size_w;
for (int mask_idx = blockIdx.x; mask_idx < num_masks; mask_idx += gridDim.x) {
int cls_idx = mask_idx % num_classes;
int roi_idx = (mask_idx / num_classes) % num_rois;
Expand Down Expand Up @@ -252,8 +253,9 @@ void MRCNNTargetRun<gpu>(const MRCNNTargetParam& param, const std::vector<TBlob>
(rois.dptr_, gt_masks.dptr_, matches.dptr_, cls_targets.dptr_,
out_masks.dptr_, out_mask_cls.dptr_,
num_el, batch_size, param.num_classes, param.num_rois,
num_gtmasks, gt_height, gt_width, param.mask_size, param.sample_ratio);
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNTargetKernel);
num_gtmasks, gt_height, gt_width,
param.mask_size[0], param.mask_size[1], param.sample_ratio);
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNMaskTargetKernel);
});
}

Expand Down
57 changes: 57 additions & 0 deletions tests/python/unittest/test_contrib_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,63 @@ def test_box_decode_op():
assert_allclose(Y.asnumpy(), np.array([[[-0.0562755, -0.00865743, 0.26227552, 0.42465743], \
[0.13240421, 0.17859563, 0.93759584, 1.1174043 ]]]), atol=1e-5, rtol=1e-5)

@with_seed()
def test_op_mrcnn_mask_target():
    """Compare contrib.mrcnn_mask_target with stored reference outputs.

    The operator is only exercised when the default context is a GPU; for any
    other device type the test returns immediately without asserting anything.
    """
    if default_context().device_type != 'gpu':
        return

    device = mx.gpu(0)
    n_rois, n_classes = 2, 4
    pooled_hw = (3, 3)

    # RoIs, laid out as (B, N, 4)
    rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
                         [3.5, 5.5, 0.9, 2.4]]], ctx=device)
    # Ground-truth masks: a running counter reshaped to (1, 4, 32, 32)
    gt_masks = mx.nd.arange(0, 4*32*32, ctx=device).reshape(1, 4, 32, 32)
    # Matches, laid out as (B, N)
    matches = mx.nd.array([[2, 0]], ctx=device)
    # Class targets, laid out as (B, N)
    cls_targets = mx.nd.array([[2, 1]], ctx=device)

    mask_targets, mask_cls = mx.nd.contrib.mrcnn_mask_target(
        rois, gt_masks, matches, cls_targets,
        num_rois=n_rois,
        num_classes=n_classes,
        mask_size=pooled_hw)

    # Ground truth outputs were generated with GluonCV's target generator
    # gluoncv.model_zoo.mask_rcnn.MaskTargetGenerator(1, num_rois, num_classes, mask_size)
    # In the reference data every class slot of a RoI holds the same 3x3
    # patch, so each patch is written once and repeated per class.
    first_roi_patch = [[2193.4, 2193.7332, 2194.0667],
                       [2204.0667, 2204.4, 2204.7334],
                       [2214.7334, 2215.0667, 2215.4]]
    second_roi_patch = [[185., 185.33334, 185.66667],
                        [195.66667, 196.00002, 196.33334],
                        [206.33333, 206.66666, 207.]]
    expected_targets = mx.nd.array([[[first_roi_patch] * n_classes,
                                     [second_roi_patch] * n_classes]])

    # One row per RoI, one column per class, then broadcast over the 3x3 mask.
    class_rows = mx.nd.array([[0, 0, 1, 0], [0, 1, 0, 0]])
    expected_cls = class_rows.reshape(1, 2, 4, 1, 1)
    expected_cls = expected_cls.broadcast_axes(axis=(3, 4), size=(3, 3))

    assert_almost_equal(mask_targets.asnumpy(), expected_targets.asnumpy())
    assert_almost_equal(mask_cls.asnumpy(), expected_cls.asnumpy())

if __name__ == '__main__':
    # Executed as a script: hand this module over to the nose test runner.
    import nose
    nose.runmodule()

0 comments on commit 1cac460

Please sign in to comment.