Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Move mrcnn_mask_target op to contrib #16486

Merged
merged 1 commit into from
Oct 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@

/*!
* Copyright (c) 2019 by Contributors
* \file mrcnn_target-inl.h
* \file mrcnn_mask_target-inl.h
* \brief Mask-RCNN target generator
* \author Serge Panev
*/


#ifndef MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_
#define MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_
#ifndef MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
#define MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_

#include <mxnet/operator.h>
#include <vector>
Expand All @@ -42,13 +42,13 @@ namespace mrcnn_index {
enum ROIAlignOpOutputs {kMask, kMaskClasses};
} // namespace mrcnn_index

struct MRCNNTargetParam : public dmlc::Parameter<MRCNNTargetParam> {
struct MRCNNMaskTargetParam : public dmlc::Parameter<MRCNNMaskTargetParam> {
int num_rois;
int num_classes;
int mask_size;
int sample_ratio;

DMLC_DECLARE_PARAMETER(MRCNNTargetParam) {
DMLC_DECLARE_PARAMETER(MRCNNMaskTargetParam) {
DMLC_DECLARE_FIELD(num_rois)
.describe("Number of sampled RoIs.");
DMLC_DECLARE_FIELD(num_classes)
Expand All @@ -60,11 +60,11 @@ struct MRCNNTargetParam : public dmlc::Parameter<MRCNNTargetParam> {
}
};

inline bool MRCNNTargetShape(const NodeAttrs& attrs,
std::vector<mxnet::TShape>* in_shape,
std::vector<mxnet::TShape>* out_shape) {
inline bool MRCNNMaskTargetShape(const NodeAttrs& attrs,
std::vector<mxnet::TShape>* in_shape,
std::vector<mxnet::TShape>* out_shape) {
using namespace mshadow;
const MRCNNTargetParam& param = nnvm::get<MRCNNTargetParam>(attrs.parsed);
const MRCNNMaskTargetParam& param = nnvm::get<MRCNNMaskTargetParam>(attrs.parsed);

CHECK_EQ(in_shape->size(), 4) << "Input:[rois, gt_masks, matches, cls_targets]";

Expand Down Expand Up @@ -98,9 +98,9 @@ inline bool MRCNNTargetShape(const NodeAttrs& attrs,
return true;
}

inline bool MRCNNTargetType(const NodeAttrs& attrs,
std::vector<int>* in_type,
std::vector<int>* out_type) {
inline bool MRCNNMaskTargetType(const NodeAttrs& attrs,
std::vector<int>* in_type,
std::vector<int>* out_type) {
CHECK_EQ(in_type->size(), 4);
int dtype = (*in_type)[1];
CHECK_NE(dtype, -1) << "Input must have specified type";
Expand All @@ -112,21 +112,21 @@ inline bool MRCNNTargetType(const NodeAttrs& attrs,
}

/*!
 * \brief Device-specific worker of the operator, templated on the device type xpu.
 *
 * Only declared here. It receives the parsed operator parameters, the input and
 * output blobs and the stream to run on. A gpu specialization is defined in the
 * accompanying .cu source.
 */
template<typename xpu>
void MRCNNTargetRun(const MRCNNTargetParam& param, const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs, mshadow::Stream<xpu> *s);
void MRCNNMaskTargetRun(const MRCNNMaskTargetParam& param, const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs, mshadow::Stream<xpu> *s);

/*!
 * \brief Compute entry point of the operator, templated on the device type xpu.
 *
 * Fetches the xpu stream from the op context, reads the parsed parameter struct
 * from attrs.parsed and forwards inputs/outputs to the Run worker for xpu.
 * The req vector is accepted to match the compute signature but is not
 * consulted inside this function.
 */
template<typename xpu>
void MRCNNTargetCompute(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
void MRCNNMaskTargetCompute(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
// Stream on which the device worker will run.
auto s = ctx.get_stream<xpu>();
// Parameters were parsed ahead of time and stored on the node attributes.
const auto& p = dmlc::get<MRCNNTargetParam>(attrs.parsed);
MRCNNTargetRun<xpu>(p, inputs, outputs, s);
const auto& p = dmlc::get<MRCNNMaskTargetParam>(attrs.parsed);
MRCNNMaskTargetRun<xpu>(p, inputs, outputs, s);
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_
#endif // MXNET_OPERATOR_CONTRIB_MRCNN_MASK_TARGET_INL_H_
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

/*!
* Copyright (c) 2019 by Contributors
* \file mrcnn_target.cu
* \file mrcnn_mask_target.cu
* \brief Mask-RCNN target generator
* \author Serge Panev
*/

#include "./mrcnn_target-inl.h"
#include "./mrcnn_mask_target-inl.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -183,21 +183,21 @@ __device__ void RoIAlignForward(


template<typename DType>
__global__ void MRCNNTargetKernel(const DType *rois,
const DType *gt_masks,
const DType *matches,
const DType *cls_targets,
DType* sampled_masks,
DType* mask_cls,
const int total_out_el,
int batch_size,
int num_classes,
int num_rois,
int num_gtmasks,
int gt_height,
int gt_width,
int mask_size,
int sample_ratio) {
__global__ void MRCNNMaskTargetKernel(const DType *rois,
const DType *gt_masks,
const DType *matches,
const DType *cls_targets,
DType* sampled_masks,
DType* mask_cls,
const int total_out_el,
int batch_size,
int num_classes,
int num_rois,
int num_gtmasks,
int gt_height,
int gt_width,
int mask_size,
int sample_ratio) {
// computing sampled_masks
RoIAlignForward(gt_masks, rois, matches, total_out_el,
num_classes, gt_height, gt_width, mask_size, mask_size,
Expand All @@ -221,8 +221,8 @@ __global__ void MRCNNTargetKernel(const DType *rois,
}

template<>
void MRCNNTargetRun<gpu>(const MRCNNTargetParam& param, const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs, mshadow::Stream<gpu> *s) {
void MRCNNMaskTargetRun<gpu>(const MRCNNMaskTargetParam& param, const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs, mshadow::Stream<gpu> *s) {
const int block_dim_size = kMaxThreadsPerBlock;
using namespace mxnet_op;
using mshadow::Tensor;
Expand All @@ -248,31 +248,31 @@ void MRCNNTargetRun<gpu>(const MRCNNTargetParam& param, const std::vector<TBlob>
dim3 dimGrid = dim3(CUDA_GET_BLOCKS(num_el));
dim3 dimBlock = dim3(block_dim_size);

MRCNNTargetKernel<<<dimGrid, dimBlock, 0, stream>>>
MRCNNMaskTargetKernel<<<dimGrid, dimBlock, 0, stream>>>
(rois.dptr_, gt_masks.dptr_, matches.dptr_, cls_targets.dptr_,
out_masks.dptr_, out_mask_cls.dptr_,
num_el, batch_size, param.num_classes, param.num_rois,
num_gtmasks, gt_height, gt_width, param.mask_size, param.sample_ratio);
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNTargetKernel);
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNMaskTargetKernel);
});
}

// Register the operator's parameter struct with dmlc.
DMLC_REGISTER_PARAMETER(MRCNNTargetParam);
DMLC_REGISTER_PARAMETER(MRCNNMaskTargetParam);

// Operator registration: four inputs (rois, gt_masks, matches, cls_targets) and
// two outputs. Attributes are parsed into the parameter struct, and shape/type
// inference use the helpers from the -inl.h header. Only a gpu FCompute is
// attached in this registration; no cpu FCompute is visible here.
NNVM_REGISTER_OP(mrcnn_target)
NNVM_REGISTER_OP(_contrib_mrcnn_mask_target)
.describe("Generate mask targets for Mask-RCNN.")
.set_num_inputs(4)
.set_num_outputs(2)
.set_attr_parser(ParamParser<MRCNNTargetParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MRCNNTargetShape)
.set_attr<nnvm::FInferType>("FInferType", MRCNNTargetType)
.set_attr<FCompute>("FCompute<gpu>", MRCNNTargetCompute<gpu>)
.set_attr_parser(ParamParser<MRCNNMaskTargetParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MRCNNMaskTargetShape)
.set_attr<nnvm::FInferType>("FInferType", MRCNNMaskTargetType)
.set_attr<FCompute>("FCompute<gpu>", MRCNNMaskTargetCompute<gpu>)
.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 3D array")
.add_argument("gt_masks", "NDArray-or-Symbol", "Input masks of full image size, a 4D array")
.add_argument("matches", "NDArray-or-Symbol", "Index to a gt_mask, a 2D array")
.add_argument("cls_targets", "NDArray-or-Symbol",
"Value [0, num_class), excluding background class, a 2D array")
.add_arguments(MRCNNTargetParam::__FIELDS__());
.add_arguments(MRCNNMaskTargetParam::__FIELDS__());

} // namespace op
} // namespace mxnet
58 changes: 58 additions & 0 deletions tests/python/unittest/test_contrib_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import itertools
from numpy.testing import assert_allclose, assert_array_equal
from mxnet.test_utils import *
from common import with_seed
import unittest

def test_box_nms_op():
Expand Down Expand Up @@ -351,6 +352,63 @@ def test_box_decode_op():
assert_allclose(Y.asnumpy(), np.array([[[-0.0562755, -0.00865743, 0.26227552, 0.42465743], \
[0.13240421, 0.17859563, 0.93759584, 1.1174043 ]]]), atol=1e-5, rtol=1e-5)

@with_seed()
def test_op_mrcnn_mask_target():
    # Compares contrib.mrcnn_mask_target with reference values. The test is a
    # no-op unless the default context is a GPU.
    running_on_gpu = default_context().device_type == 'gpu'
    if not running_on_gpu:
        return

    ctx = mx.gpu(0)
    num_rois, num_classes, mask_size = 2, 4, 3

    # rois: (B, N, 4)
    rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
                         [3.5, 5.5, 0.9, 2.4]]], ctx=ctx)
    # gt_masks: 4 masks of 32x32 filled with consecutive values
    gt_masks = mx.nd.arange(0, 4*32*32, ctx=ctx).reshape(1, 4, 32, 32)
    # matches: (B, N), index of the gt mask assigned to each roi
    matches = mx.nd.array([[2, 0]], ctx=ctx)
    # cls_targets: (B, N), class of each roi
    cls_targets = mx.nd.array([[2, 1]], ctx=ctx)

    op_kwargs = dict(num_rois=num_rois,
                     num_classes=num_classes,
                     mask_size=mask_size)
    mask_targets, mask_cls = mx.nd.contrib.mrcnn_mask_target(
        rois, gt_masks, matches, cls_targets, **op_kwargs)

    # Ground truth outputs were generated with GluonCV's target generator
    # gluoncv.model_zoo.mask_rcnn.MaskTargetGenerator(1, num_rois, num_classes, mask_size)
    # The sampled mask of a roi is repeated once per class.
    first_roi_mask = [[2193.4, 2193.7332, 2194.0667],
                      [2204.0667, 2204.4, 2204.7334],
                      [2214.7334, 2215.0667, 2215.4]]
    second_roi_mask = [[185., 185.33334, 185.66667],
                       [195.66667, 196.00002, 196.33334],
                       [206.33333, 206.66666, 207.]]
    gt_mask_targets = mx.nd.array([[[first_roi_mask] * 4,
                                    [second_roi_mask] * 4]])

    # One-hot class rows, broadcast over the 3x3 mask window.
    one_hot_cls = mx.nd.array([[0, 0, 1, 0], [0, 1, 0, 0]])
    gt_mask_cls = one_hot_cls.reshape(1, 2, 4, 1, 1) \
                             .broadcast_axes(axis=(3, 4), size=(3, 3))

    assert_almost_equal(mask_targets.asnumpy(), gt_mask_targets.asnumpy())
    assert_almost_equal(mask_cls.asnumpy(), gt_mask_cls.asnumpy())

if __name__ == '__main__':
import nose
nose.runmodule()
57 changes: 0 additions & 57 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8686,63 +8686,6 @@ def test_rroi_align_value(sampling_ratio=-1):
test_rroi_align_value()
test_rroi_align_value(sampling_ratio=2)

@with_seed()
def test_op_mrcnn_target():
    # Compares mx.nd.mrcnn_target with reference values; returns immediately
    # when the default context is not a GPU.
    if default_context().device_type != 'gpu':
        return

    ctx = mx.gpu(0)
    num_rois = 2
    num_classes = 4
    mask_size = 3

    # Inputs: rois (B, N, 4), gt_masks, matches (B, N) and cls_targets (B, N)
    rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
                         [3.5, 5.5, 0.9, 2.4]]], ctx=ctx)
    gt_masks = mx.nd.arange(0, 4*32*32, ctx=ctx).reshape(1, 4, 32, 32)
    matches = mx.nd.array([[2, 0]], ctx=ctx)
    cls_targets = mx.nd.array([[2, 1]], ctx=ctx)

    outputs = mx.nd.mrcnn_target(rois, gt_masks, matches, cls_targets,
                                 num_rois=num_rois,
                                 num_classes=num_classes,
                                 mask_size=mask_size)
    mask_targets, mask_cls = outputs

    # Ground truth outputs were generated with GluonCV's target generator
    # gluoncv.model_zoo.mask_rcnn.MaskTargetGenerator(1, num_rois, num_classes, mask_size)
    per_roi_masks = (
        [[2193.4, 2193.7332, 2194.0667],
         [2204.0667, 2204.4, 2204.7334],
         [2214.7334, 2215.0667, 2215.4]],
        [[185., 185.33334, 185.66667],
         [195.66667, 196.00002, 196.33334],
         [206.33333, 206.66666, 207.]],
    )
    # Every roi's mask appears once per class (4 copies each).
    gt_mask_targets = mx.nd.array(
        [[[roi_mask for _ in range(4)] for roi_mask in per_roi_masks]])

    gt_mask_cls = mx.nd.array([[0, 0, 1, 0], [0, 1, 0, 0]])
    gt_mask_cls = gt_mask_cls.reshape(1, 2, 4, 1, 1)
    gt_mask_cls = gt_mask_cls.broadcast_axes(axis=(3, 4), size=(3, 3))

    assert_almost_equal(mask_targets.asnumpy(), gt_mask_targets.asnumpy())
    assert_almost_equal(mask_cls.asnumpy(), gt_mask_cls.asnumpy())

@with_seed()
def test_diag():

Expand Down