Move MRCNNMaskTarget op to contrib (apache#16486)
Signed-off-by: Serge Panev <[email protected]>
Kh4L authored and aaronmarkham committed Oct 16, 2019
1 parent 81a64cc commit 9c90d60
Showing 4 changed files with 108 additions and 107 deletions.
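
For downstream users, the visible effect of the move is the operator's registered name and namespace: mrcnn_target becomes _contrib_mrcnn_mask_target, exposed in Python as mx.nd.contrib.mrcnn_mask_target. The sketch below is a minimal before/after illustration and not part of the commit; it reuses the toy input shapes from the new unit test further down and assumes a GPU context, since the diff registers only an FCompute<gpu> kernel.

import mxnet as mx

ctx = mx.gpu(0)  # the new unit test also guards on a GPU context
# Toy inputs with the shapes used by the new unit test:
rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
                     [3.5, 5.5, 0.9, 2.4]]], ctx=ctx)               # (B, N, 4)
gt_masks = mx.nd.arange(0, 4*32*32, ctx=ctx).reshape(1, 4, 32, 32)  # (B, M, H, W)
matches = mx.nd.array([[2, 0]], ctx=ctx)                            # (B, N)
cls_targets = mx.nd.array([[2, 1]], ctx=ctx)                        # (B, N)

# Before this commit the operator lived in the main namespace:
# mask_targets, mask_cls = mx.nd.mrcnn_target(rois, gt_masks, matches, cls_targets,
#                                             num_rois=2, num_classes=4, mask_size=3)
# After this commit it is reached through contrib:
mask_targets, mask_cls = mx.nd.contrib.mrcnn_mask_target(rois, gt_masks, matches, cls_targets,
                                                          num_rois=2, num_classes=4, mask_size=3)
print(mask_targets.shape, mask_cls.shape)  # (1, 2, 4, 3, 3) for both, per the test below
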
mrcnn_target-inl.h → mrcnn_mask_target-inl.h
@@ -19,14 +19,14 @@

/*!
* Copyright (c) 2019 by Contributors
- * \file mrcnn_target-inl.h
+ * \file mrcnn_mask_target-inl.h
* \brief Mask-RCNN target generator
* \author Serge Panev
*/


- #ifndef MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_
- #define MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_
+ #ifndef MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
+ #define MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_

#include <mxnet/operator.h>
#include <vector>
@@ -42,13 +42,13 @@ namespace mrcnn_index {
enum ROIAlignOpOutputs {kMask, kMaskClasses};
} // namespace mrcnn_index

- struct MRCNNTargetParam : public dmlc::Parameter<MRCNNTargetParam> {
+ struct MRCNNMaskTargetParam : public dmlc::Parameter<MRCNNMaskTargetParam> {
int num_rois;
int num_classes;
int mask_size;
int sample_ratio;

- DMLC_DECLARE_PARAMETER(MRCNNTargetParam) {
+ DMLC_DECLARE_PARAMETER(MRCNNMaskTargetParam) {
DMLC_DECLARE_FIELD(num_rois)
.describe("Number of sampled RoIs.");
DMLC_DECLARE_FIELD(num_classes)
@@ -60,11 +60,11 @@ struct MRCNNTargetParam : public dmlc::Parameter<MRCNNTargetParam> {
}
};

- inline bool MRCNNTargetShape(const NodeAttrs& attrs,
- std::vector<mxnet::TShape>* in_shape,
- std::vector<mxnet::TShape>* out_shape) {
+ inline bool MRCNNMaskTargetShape(const NodeAttrs& attrs,
+ std::vector<mxnet::TShape>* in_shape,
+ std::vector<mxnet::TShape>* out_shape) {
using namespace mshadow;
- const MRCNNTargetParam& param = nnvm::get<MRCNNTargetParam>(attrs.parsed);
+ const MRCNNMaskTargetParam& param = nnvm::get<MRCNNMaskTargetParam>(attrs.parsed);

CHECK_EQ(in_shape->size(), 4) << "Input:[rois, gt_masks, matches, cls_targets]";

@@ -98,9 +98,9 @@ inline bool MRCNNTargetShape(const NodeAttrs& attrs,
return true;
}

- inline bool MRCNNTargetType(const NodeAttrs& attrs,
- std::vector<int>* in_type,
- std::vector<int>* out_type) {
+ inline bool MRCNNMaskTargetType(const NodeAttrs& attrs,
+ std::vector<int>* in_type,
+ std::vector<int>* out_type) {
CHECK_EQ(in_type->size(), 4);
int dtype = (*in_type)[1];
CHECK_NE(dtype, -1) << "Input must have specified type";
@@ -112,21 +112,21 @@ inline bool MRCNNTargetType(const NodeAttrs& attrs,
}

template<typename xpu>
- void MRCNNTargetRun(const MRCNNTargetParam& param, const std::vector<TBlob> &inputs,
- const std::vector<TBlob> &outputs, mshadow::Stream<xpu> *s);
+ void MRCNNMaskTargetRun(const MRCNNMaskTargetParam& param, const std::vector<TBlob> &inputs,
+ const std::vector<TBlob> &outputs, mshadow::Stream<xpu> *s);

template<typename xpu>
- void MRCNNTargetCompute(const nnvm::NodeAttrs& attrs,
- const OpContext &ctx,
- const std::vector<TBlob> &inputs,
- const std::vector<OpReqType> &req,
- const std::vector<TBlob> &outputs) {
+ void MRCNNMaskTargetCompute(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &outputs) {
auto s = ctx.get_stream<xpu>();
- const auto& p = dmlc::get<MRCNNTargetParam>(attrs.parsed);
- MRCNNTargetRun<xpu>(p, inputs, outputs, s);
+ const auto& p = dmlc::get<MRCNNMaskTargetParam>(attrs.parsed);
+ MRCNNMaskTargetRun<xpu>(p, inputs, outputs, s);
}

} // namespace op
} // namespace mxnet

- #endif // MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_
+ #endif // MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
mrcnn_target.cu → mrcnn_mask_target.cu
@@ -19,12 +19,12 @@

/*!
* Copyright (c) 2019 by Contributors
- * \file mrcnn_target.cu
+ * \file mrcnn_mask_target.cu
* \brief Mask-RCNN target generator
* \author Serge Panev
*/

- #include "./mrcnn_target-inl.h"
+ #include "./mrcnn_mask_target-inl.h"

namespace mxnet {
namespace op {
@@ -183,21 +183,21 @@ __device__ void RoIAlignForward(


template<typename DType>
- __global__ void MRCNNTargetKernel(const DType *rois,
- const DType *gt_masks,
- const DType *matches,
- const DType *cls_targets,
- DType* sampled_masks,
- DType* mask_cls,
- const int total_out_el,
- int batch_size,
- int num_classes,
- int num_rois,
- int num_gtmasks,
- int gt_height,
- int gt_width,
- int mask_size,
- int sample_ratio) {
+ __global__ void MRCNNMaskTargetKernel(const DType *rois,
+ const DType *gt_masks,
+ const DType *matches,
+ const DType *cls_targets,
+ DType* sampled_masks,
+ DType* mask_cls,
+ const int total_out_el,
+ int batch_size,
+ int num_classes,
+ int num_rois,
+ int num_gtmasks,
+ int gt_height,
+ int gt_width,
+ int mask_size,
+ int sample_ratio) {
// computing sampled_masks
RoIAlignForward(gt_masks, rois, matches, total_out_el,
num_classes, gt_height, gt_width, mask_size, mask_size,
@@ -221,8 +221,8 @@ __global__ void MRCNNTargetKernel(const DType *rois,
}

template<>
- void MRCNNTargetRun<gpu>(const MRCNNTargetParam& param, const std::vector<TBlob> &inputs,
- const std::vector<TBlob> &outputs, mshadow::Stream<gpu> *s) {
+ void MRCNNMaskTargetRun<gpu>(const MRCNNMaskTargetParam& param, const std::vector<TBlob> &inputs,
+ const std::vector<TBlob> &outputs, mshadow::Stream<gpu> *s) {
const int block_dim_size = kMaxThreadsPerBlock;
using namespace mxnet_op;
using mshadow::Tensor;
@@ -248,31 +248,31 @@ void MRCNNTargetRun<gpu>(const MRCNNTargetParam& param, const std::vector<TBlob>
dim3 dimGrid = dim3(CUDA_GET_BLOCKS(num_el));
dim3 dimBlock = dim3(block_dim_size);

- MRCNNTargetKernel<<<dimGrid, dimBlock, 0, stream>>>
+ MRCNNMaskTargetKernel<<<dimGrid, dimBlock, 0, stream>>>
(rois.dptr_, gt_masks.dptr_, matches.dptr_, cls_targets.dptr_,
out_masks.dptr_, out_mask_cls.dptr_,
num_el, batch_size, param.num_classes, param.num_rois,
num_gtmasks, gt_height, gt_width, param.mask_size, param.sample_ratio);
- MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNTargetKernel);
+ MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNMaskTargetKernel);
});
}

- DMLC_REGISTER_PARAMETER(MRCNNTargetParam);
+ DMLC_REGISTER_PARAMETER(MRCNNMaskTargetParam);

- NNVM_REGISTER_OP(mrcnn_target)
+ NNVM_REGISTER_OP(_contrib_mrcnn_mask_target)
.describe("Generate mask targets for Mask-RCNN.")
.set_num_inputs(4)
.set_num_outputs(2)
- .set_attr_parser(ParamParser<MRCNNTargetParam>)
- .set_attr<mxnet::FInferShape>("FInferShape", MRCNNTargetShape)
- .set_attr<nnvm::FInferType>("FInferType", MRCNNTargetType)
- .set_attr<FCompute>("FCompute<gpu>", MRCNNTargetCompute<gpu>)
+ .set_attr_parser(ParamParser<MRCNNMaskTargetParam>)
+ .set_attr<mxnet::FInferShape>("FInferShape", MRCNNMaskTargetShape)
+ .set_attr<nnvm::FInferType>("FInferType", MRCNNMaskTargetType)
+ .set_attr<FCompute>("FCompute<gpu>", MRCNNMaskTargetCompute<gpu>)
.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 3D array")
.add_argument("gt_masks", "NDArray-or-Symbol", "Input masks of full image size, a 4D array")
.add_argument("matches", "NDArray-or-Symbol", "Index to a gt_mask, a 2D array")
.add_argument("cls_targets", "NDArray-or-Symbol",
"Value [0, num_class), excluding background class, a 2D array")
- .add_arguments(MRCNNTargetParam::__FIELDS__());
+ .add_arguments(MRCNNMaskTargetParam::__FIELDS__());

} // namespace op
} // namespace mxnet
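
The shape-inference hunks above are partially collapsed, so for orientation: as exercised by the new unit test below, rois of shape (B, N, 4) with N = num_rois yield two outputs of shape (B, num_rois, num_classes, mask_size, mask_size). The helper below is a sketch that merely restates that relation; its name is made up for illustration and it is not the C++ MRCNNMaskTargetShape implementation.

def mrcnn_mask_target_output_shapes(rois_shape, num_classes, mask_size):
    # rois: (B, N, 4); both outputs: (B, N, num_classes, mask_size, mask_size)
    batch, num_rois, _ = rois_shape
    out = (batch, num_rois, num_classes, mask_size, mask_size)
    return out, out  # (mask targets, per-class mask weights)

# Matches the shapes of the ground-truth arrays asserted in test_op_mrcnn_mask_target below:
assert mrcnn_mask_target_output_shapes((1, 2, 4), num_classes=4, mask_size=3) == \
       ((1, 2, 4, 3, 3), (1, 2, 4, 3, 3))
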
tests/python/unittest/test_contrib_operator.py (58 additions, 0 deletions)
@@ -23,6 +23,7 @@
import itertools
from numpy.testing import assert_allclose, assert_array_equal
from mxnet.test_utils import *
from common import with_seed
import unittest

def test_box_nms_op():
@@ -351,6 +352,63 @@ def test_box_decode_op():
assert_allclose(Y.asnumpy(), np.array([[[-0.0562755, -0.00865743, 0.26227552, 0.42465743], \
[0.13240421, 0.17859563, 0.93759584, 1.1174043 ]]]), atol=1e-5, rtol=1e-5)

@with_seed()
def test_op_mrcnn_mask_target():
if default_context().device_type != 'gpu':
return

num_rois = 2
num_classes = 4
mask_size = 3
ctx = mx.gpu(0)
# (B, N, 4)
rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
[3.5, 5.5, 0.9, 2.4]]], ctx=ctx)
gt_masks = mx.nd.arange(0, 4*32*32, ctx=ctx).reshape(1, 4, 32, 32)

# (B, N)
matches = mx.nd.array([[2, 0]], ctx=ctx)
# (B, N)
cls_targets = mx.nd.array([[2, 1]], ctx=ctx)

mask_targets, mask_cls = mx.nd.contrib.mrcnn_mask_target(rois, gt_masks, matches, cls_targets,
num_rois=num_rois,
num_classes=num_classes,
mask_size=mask_size)

# Ground truth outputs were generated with GluonCV's target generator
# gluoncv.model_zoo.mask_rcnn.MaskTargetGenerator(1, num_rois, num_classes, mask_size)
gt_mask_targets = mx.nd.array([[[[[2193.4 , 2193.7332 , 2194.0667 ],
[2204.0667 , 2204.4 , 2204.7334 ],
[2214.7334 , 2215.0667 , 2215.4 ]],
[[2193.4 , 2193.7332 , 2194.0667 ],
[2204.0667 , 2204.4 , 2204.7334 ],
[2214.7334 , 2215.0667 , 2215.4 ]],
[[2193.4 , 2193.7332 , 2194.0667 ],
[2204.0667 , 2204.4 , 2204.7334 ],
[2214.7334 , 2215.0667 , 2215.4 ]],
[[2193.4 , 2193.7332 , 2194.0667 ],
[2204.0667 , 2204.4 , 2204.7334 ],
[2214.7334 , 2215.0667 , 2215.4 ]]],
[[[ 185. , 185.33334, 185.66667],
[ 195.66667, 196.00002, 196.33334],
[ 206.33333, 206.66666, 207. ]],
[[ 185. , 185.33334, 185.66667],
[ 195.66667, 196.00002, 196.33334],
[ 206.33333, 206.66666, 207. ]],
[[ 185. , 185.33334, 185.66667],
[ 195.66667, 196.00002, 196.33334],
[ 206.33333, 206.66666, 207. ]],
[[ 185. , 185.33334, 185.66667],
[ 195.66667, 196.00002, 196.33334],
[ 206.33333, 206.66666, 207. ]]]]])

gt_mask_cls = mx.nd.array([[0,0,1,0], [0,1,0,0]])
gt_mask_cls = gt_mask_cls.reshape(1,2,4,1,1).broadcast_axes(axis=(3,4), size=(3,3))

assert_almost_equal(mask_targets.asnumpy(), gt_mask_targets.asnumpy())
assert_almost_equal(mask_cls.asnumpy(), gt_mask_cls.asnumpy())

if __name__ == '__main__':
import nose
nose.runmodule()
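
The mask_cls ground truth built in the test above is, in effect, a one-hot encoding of cls_targets over the class axis, broadcast to the mask resolution. The NumPy sketch below is illustrative only (names are local to the sketch) and reproduces the [[0,0,1,0], [0,1,0,0]] pattern used in the assertions.

import numpy as np

cls_targets = np.array([[2, 1]])   # (B, N), as in the test
num_classes, mask_size = 4, 3
# One-hot over the class axis: (B, N, num_classes)
one_hot = (np.arange(num_classes) == cls_targets[..., None]).astype(np.float32)
# Broadcast each one-hot vector over a mask_size x mask_size window: (B, N, C, MS, MS)
gt_mask_cls = np.broadcast_to(one_hot[..., None, None],
                              cls_targets.shape + (num_classes, mask_size, mask_size))
# gt_mask_cls[0, 0] is all ones in class channel 2 and zeros elsewhere;
# gt_mask_cls[0, 1] is all ones in class channel 1 and zeros elsewhere.
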
tests/python/unittest/test_operator.py (0 additions, 57 deletions)
@@ -8686,63 +8686,6 @@ def test_rroi_align_value(sampling_ratio=-1):
test_rroi_align_value()
test_rroi_align_value(sampling_ratio=2)

@with_seed()
def test_op_mrcnn_target():
if default_context().device_type != 'gpu':
return

num_rois = 2
num_classes = 4
mask_size = 3
ctx = mx.gpu(0)
# (B, N, 4)
rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
[3.5, 5.5, 0.9, 2.4]]], ctx=ctx)
gt_masks = mx.nd.arange(0, 4*32*32, ctx=ctx).reshape(1, 4, 32, 32)

# (B, N)
matches = mx.nd.array([[2, 0]], ctx=ctx)
# (B, N)
cls_targets = mx.nd.array([[2, 1]], ctx=ctx)

mask_targets, mask_cls = mx.nd.mrcnn_target(rois, gt_masks, matches, cls_targets,
num_rois=num_rois,
num_classes=num_classes,
mask_size=mask_size)

# Ground truth outputs were generated with GluonCV's target generator
# gluoncv.model_zoo.mask_rcnn.MaskTargetGenerator(1, num_rois, num_classes, mask_size)
gt_mask_targets = mx.nd.array([[[[[2193.4 , 2193.7332 , 2194.0667 ],
[2204.0667 , 2204.4 , 2204.7334 ],
[2214.7334 , 2215.0667 , 2215.4 ]],
[[2193.4 , 2193.7332 , 2194.0667 ],
[2204.0667 , 2204.4 , 2204.7334 ],
[2214.7334 , 2215.0667 , 2215.4 ]],
[[2193.4 , 2193.7332 , 2194.0667 ],
[2204.0667 , 2204.4 , 2204.7334 ],
[2214.7334 , 2215.0667 , 2215.4 ]],
[[2193.4 , 2193.7332 , 2194.0667 ],
[2204.0667 , 2204.4 , 2204.7334 ],
[2214.7334 , 2215.0667 , 2215.4 ]]],
[[[ 185. , 185.33334, 185.66667],
[ 195.66667, 196.00002, 196.33334],
[ 206.33333, 206.66666, 207. ]],
[[ 185. , 185.33334, 185.66667],
[ 195.66667, 196.00002, 196.33334],
[ 206.33333, 206.66666, 207. ]],
[[ 185. , 185.33334, 185.66667],
[ 195.66667, 196.00002, 196.33334],
[ 206.33333, 206.66666, 207. ]],
[[ 185. , 185.33334, 185.66667],
[ 195.66667, 196.00002, 196.33334],
[ 206.33333, 206.66666, 207. ]]]]])

gt_mask_cls = mx.nd.array([[0,0,1,0], [0,1,0,0]])
gt_mask_cls = gt_mask_cls.reshape(1,2,4,1,1).broadcast_axes(axis=(3,4), size=(3,3))

assert_almost_equal(mask_targets.asnumpy(), gt_mask_targets.asnumpy())
assert_almost_equal(mask_cls.asnumpy(), gt_mask_cls.asnumpy())

@with_seed()
def test_diag():
